import inference
import os
import subprocess
from threading import Timer
from pathlib import Path
import sys
import importlib

# Module-level settings. ``model`` is the model identifier forwarded to the
# ``inference`` calls; the two empty strings are placeholders that
# ``beam_search`` overwrites (via ``global``) with the values ``setup`` reads.
model = "gpt4o"
problem_description_nl, possible_actions = "", ""

class State:
    """A node in the search tree built by ``beam_search``.

    Bundles the natural-language state description, the action (and the
    reasoning) that produced it, the full action path from the root and the
    diagram artifacts generated for it. ``depth`` is derived from ``parent``.
    """

    def __init__(self, parent, id, action_path, action_taken, state_nl=None, reasoning=None,
                 diagram_encoding=None, diagram_code=None, diagram_picture=None):
        # --- tree structure ---
        self.parent = parent
        self.id = id
        # --- how this state was reached ---
        self.action_path = action_path          # actions from the root up to here
        self.action_taken = action_taken        # the last action of action_path
        self.action_reasoning_nl = reasoning
        self.state_nl = state_nl
        self.child_states = []
        # --- diagram artifacts ---
        self.diagram_encoding = diagram_encoding
        self.diagram_code = diagram_code
        self.diagram_picture = diagram_picture  # path of the rendered image
        # --- search bookkeeping ---
        self.active = True  # cleared when the node is pruned or rejected
        if parent:
            self.depth = parent.depth + 1
        else:
            self.depth = 0
        self.cost = 0
        self.action_path_reasoning = {}  # "action <i>:<action>" -> reasoning text
        self.count_added_to_backtrack = 0

def ini_g_diagrams(problem_name, problem_dir, problem_description, initial_state, goal_state, model, goal=False, params=None, error_message=None):
    """Generate, test and store the diagram of the initial or the goal state.

    Runs two retry loops. Each loop is bounded by ``params['Max_diag_attempts']``
    (default 10), starts at temperature ``params['Diag_initial_temp']``
    (default 0.0) and raises it by ``params['Diag_temp_incr']`` (default 0.1)
    after every failed attempt:

    1. encoding: ``inference.generate_diagram_encoding_ini_g`` followed by
       ``inference.test_diagram_encoding_ini_g``;
    2. code: ``inference.generate_diagram_code_ini_g``, execution through
       ``execute_diagram_code``, then ``inference.test_diagram_ini_g``.

    After a failure, the failed output and its error message are passed to the
    next generation call. If a loop runs out of attempts, the first artifact it
    generated is kept as a fallback and the result is flagged as untested.

    Args:
        problem_name: instance name; the domain is the part before ``"_instance"``.
        problem_dir: directory receiving attempt logs and the stored diagram.
        problem_description: natural-language problem description.
        initial_state: ``State`` of the initial configuration.
        goal_state: ``State`` of the goal configuration.
        model: model identifier forwarded to ``inference``.
        goal: diagram the goal state if True, otherwise the initial state.
        params: optional dict with the tuning keys listed above.
        error_message: optional error fed to the first encoding attempt.

    Returns:
        True if a fallback artifact that did not pass its test was stored,
        otherwise False.

    Raises:
        ValueError: if a loop exhausted its attempts without ever producing an
            artifact to fall back on.
    """
    if params is None:
        params = {}
    domain_name = problem_name.split("_instance")[0]
    max_attempts = params.get('Max_diag_attempts', 10)
    temp_incr = params.get('Diag_temp_incr', 0.1)
    state = goal_state if goal else initial_state
    print(f"getting diagram of {'goal' if goal else 'initial'} state")
    not_pass_tests = False

    # ---- Phase 1: diagram encoding ------------------------------------------
    diagram_temp = params.get('Diag_initial_temp', 0.0)
    attempts = 0
    previous_attempt = None
    # Bound up front so a failure in the very first generation call does not
    # leave the name unbound for the failure bookkeeping below.
    diagram_encoding = None
    fallback_diagram_encoding = None
    while attempts < max_attempts:
        try:
            diagram_encoding = inference.generate_diagram_encoding_ini_g(
                problem_description, initial_state, goal_state, model, diagram_temp, domain_name, goal, previous_attempt, error_message
            )
            if fallback_diagram_encoding is None:
                fallback_diagram_encoding = diagram_encoding
            state.diagram_encoding = diagram_encoding

            validity, error_message = inference.test_diagram_encoding_ini_g(
                problem_description, initial_state, goal_state, model, goal
            )
            if validity:
                break
        except Exception as e:
            error_message = str(e)
        # Failed attempt (invalid encoding or exception): log it, warm up the
        # sampling and feed the failed encoding back to the next call.
        record_diagram_attempt(problem_dir, error_message, state, attempts + 1)
        diagram_temp += temp_incr
        previous_attempt = diagram_encoding
        attempts += 1

    if attempts >= max_attempts:
        if fallback_diagram_encoding is None:
            raise ValueError("Ran out of max attempts when generating diagram encoding")
        state.diagram_encoding = fallback_diagram_encoding
        not_pass_tests = True

    # ---- Phase 2: diagram code & picture ------------------------------------
    diagram_temp = params.get('Diag_initial_temp', 0.0)
    attempts = 0
    previous_attempt = None
    error_message = None
    diagram_code = None
    fallback_diagram_code = None
    fallback_diagram_picture = None
    while attempts < max_attempts:
        try:
            save_path = f"{problem_dir}<PATH_REMOVED>{attempts + 1}.png"
            diagram_code = inference.generate_diagram_code_ini_g(
                domain_name, problem_description, initial_state, goal_state, model,
                diagram_temp, save_path, goal, previous_attempt, error_message
            )
            if fallback_diagram_code is None:
                fallback_diagram_code = diagram_code
                fallback_diagram_picture = save_path
            state.diagram_code = diagram_code

            exec_error = execute_diagram_code(problem_dir, diagram_code, state.id, attempts + 1, save_path)
            state.diagram_picture = save_path
            if exec_error:
                # Feed the actual execution error back to the model.
                error_message = f"Code execution failed: {exec_error}"
            else:
                validity, error_message = inference.test_diagram_ini_g(
                    problem_description, initial_state, goal_state, domain_name, model, goal
                )
                if validity:
                    break
        except Exception as e:
            error_message = str(e)
        # Failed attempt (execution error, failed test or exception).
        record_diagram_attempt(problem_dir, error_message, state, attempts + 1)
        diagram_temp += temp_incr
        previous_attempt = diagram_code
        attempts += 1

    if attempts >= max_attempts:
        if fallback_diagram_code is None:
            raise ValueError("Ran out of max attempts when generating diagram")
        state.diagram_code = fallback_diagram_code
        state.diagram_picture = fallback_diagram_picture
        not_pass_tests = True

    store_diagram(problem_dir, state, not_pass_tests)
    return not_pass_tests

def setup(problem_name, problem_dir, params):
    """Load a problem from ``problem_dir`` and prepare its endpoint states.

    Reads four text files (problem description, initial state, goal state and
    possible actions). It wraps the two states in ``State`` objects with ids
    0 and 9999 and generates the diagram of each with ``ini_g_diagrams``
    (initial first, then goal).

    Returns:
        ``(initial_state, goal_state, problem_description_nl, possible_actions)``.
    """
    if not os.path.exists(problem_dir):
        os.makedirs(problem_dir, exist_ok=True)

    def read_text(path):
        # Whole-file read in text mode.
        with open(path, 'r') as handle:
            return handle.read()

    description = read_text(f"{problem_dir}<PATH_REMOVED>")
    initial_nl = read_text(f"{problem_dir}<PATH_REMOVED>")
    goal_nl = read_text(f"{problem_dir}<PATH_REMOVED>")
    actions = read_text(f"{problem_dir}<PATH_REMOVED>")

    root = State(parent=None, id=0, action_path=[], action_taken=None, state_nl=initial_nl)
    target = State(parent=None, id=9999, action_path=[], action_taken=None, state_nl=goal_nl)

    for is_goal in (False, True):
        ini_g_diagrams(problem_name, problem_dir, description, root, target, model, is_goal, params)

    return root, target, description, actions

def setup_child_state(problem_dir, child_state_id):
    """Create the folder tree for a new child state.

    The tree is the state directory, its ``attempts`` sub-folder, and the
    ``child_attempts`` and ``diagram_attempts`` folders inside that. Existing
    folders are left untouched.
    """
    state_dir = f"{problem_dir}<PATH_REMOVED>"
    attempts_dir = os.path.join(state_dir, "attempts")
    for folder in (
        state_dir,
        attempts_dir,
        os.path.join(attempts_dir, "child_attempts"),
        os.path.join(attempts_dir, "diagram_attempts"),
    ):
        os.makedirs(folder, exist_ok=True)

def record_state_ranking(problem_dir, ranked_states, depth, reasoning):
    """Write the model's ranking of the states at ``depth`` to a text file.

    The file contains the model's reasoning, followed by one entry per state
    in ranked order (rank, state id and state description).
    """
    ranking_dir = f"{problem_dir}<PATH_REMOVED>"
    if not os.path.exists(ranking_dir):
        os.makedirs(ranking_dir, exist_ok=True)
    ranking_file = f"{ranking_dir}<PATH_REMOVED>"
    with open(ranking_file, 'w') as out:
        out.write(f"Ranking at depth {depth}\n"
                  f"Model's reasoning and response:\n{reasoning}\n"
                  "\nFinal ranking:\n")
        for rank, st in enumerate(ranked_states, start=1):
            out.write(f"Rank {rank}: State ID {st.id}\n"
                      f"State Description: {st.state_nl}\n"
                      "\n")

def beam_search(problem_name, problem_dir, pddl_module, k, params):
    """Level-by-level beam search over LLM-generated states, with backtracking.

    ``setup`` loads the problem and builds the initial and goal states. On each
    iteration, every active node in ``open_list`` is checked against the goal
    and then expanded with ``generate_child_states``. If the level produces
    more than ``k`` active children, they are ranked with
    ``inference.rank_states``; the top ``k`` are kept and the rest are marked
    inactive. Surviving nodes are remembered per depth in ``last_valid_nodes``.
    When a level yields no nodes, the search resumes from the deepest
    remembered level.

    Args:
        problem_name: instance name; the domain is the part before ``"_instance"``.
        problem_dir: working directory for all search artifacts.
        pddl_module: object providing ``path_to_pddl_format(action_path,
            domain_name, problem_name, model)``, whose result is written to
            ``plan.pddl``.
        k: beam width.
        params: dict; read here are ``"Max_depth"``, ``"Max_id"`` and
            ``"state_backtracking_limit"`` (``setup`` and
            ``generate_child_states`` read further keys).

    Returns:
        True if a goal state was found (``plan.pddl`` is written into
        ``problem_dir``). False if ``Max_depth`` was exceeded, ``Max_id`` was
        reached, or no nodes remained to expand or backtrack to.

    Side effects:
        Rebinds the module globals ``problem_description_nl`` and
        ``possible_actions`` to the values returned by ``setup``.
    """
    global possible_actions
    global problem_description_nl

    initial_state, goal_state, problem_description_nl, possible_actions = setup(problem_name, problem_dir, params)
    initial_state.depth = 0
    initial_state.active = True
    open_list = [initial_state]
    last_id = 0
    depth = 0
    domain_name = problem_name.split("_instance")[0]

    # Backtracking memory: depth -> nodes that survived pruning at that depth.
    last_valid_nodes = {}

    # Continue while there is a level to expand or a level to backtrack to.
    while open_list or last_valid_nodes:
        if depth > params["Max_depth"]:
            print("Max depth reached. Restarting")
            return False

        print(f"Expanding nodes at depth {depth}")
        next_depth_states = []

        # Expand all active nodes at current depth
        for current_state in open_list:
            if not current_state.active:
                continue

            goal_reached = inference.check_goal_state(problem_description_nl, current_state, goal_state, initial_state, model)
            if goal_reached:
                print("Found goal!!")
                print(current_state.action_path)
                record_goal_state(problem_dir, current_state)
                pddl_plan = pddl_module.path_to_pddl_format(current_state.action_path, domain_name, problem_name, model)
                plan_pddl_path = os.path.join(problem_dir, 'plan.pddl')
                with open(plan_pddl_path, 'w') as f:
                    f.write(pddl_plan)
                return True

            # last_id is the highest state id handed out so far; Max_id caps
            # how many states the whole search may create.
            if last_id < params['Max_id']:
                print(f"Expanding state {current_state.id}")
                children, last_id = generate_child_states(problem_dir, initial_state, current_state, last_id, goal_state, params, problem_name)
                record_child_states(problem_dir, current_state, children)
                # Generate diagrams/cost for child states
                #print(f"Generating diagrams and costs for child states of {current_state.id}")
                next_depth_states.extend(children)

            else:
                print("Generated max number of children! Max_id limit reached.")
                return False

        # Disabled: level-wide action-path validation. The triple-quoted
        # string below is a no-op expression kept for reference; action paths
        # are now checked per child inside generate_child_states.
        """print(f"Validating action paths for new states at depth {depth + 1}...")
        valid_next_depth_states = []
        for st in next_depth_states:
            if not st.active:
                continue
            
            # Call your inference function that validates these actions
            is_valid_path, reasoning = inference.check_action_path(
                problem_description_nl,
                initial_state,
                st,
                goal_state,
                st.action_path,
                model
            )
            
            if not is_valid_path:
                print(f"State {st.id} failed the action path verification test: {reasoning}")
                st.active = False
                update_info_file_inactive(problem_dir, st, "Inactive: action path invalid")
            else:
                valid_next_depth_states.append(st)

        next_depth_states = valid_next_depth_states"""

        # Rank and prune only when there are more active children than the
        # beam width; otherwise every child moves on.
        num_nodes_at_next_depth = len([s for s in next_depth_states if s.active])
        if num_nodes_at_next_depth > k:
            print(f"Number of valid nodes at depth {depth + 1} ({num_nodes_at_next_depth}) exceeds beam width ({k}) -> ranking...")
            ranked_states, reasoning = inference.rank_states(
                next_depth_states, problem_description_nl, goal_state, model
            )
            for idx, node in enumerate(ranked_states):
                if idx < k:
                    node.active = True
                else:
                    node.active = False
                    update_info_file_inactive(problem_dir, node, "INACTIVE REASON: Pruned in beam search")
            record_state_ranking(problem_dir, ranked_states, depth + 1, reasoning)
            next_depth_states = ranked_states[:k]
        else:
            print(f"Number of valid nodes at depth {depth + 1} ({num_nodes_at_next_depth}) "
                  f"≤ beam width ({k}), skipping ranking/pruning.")

        open_list = next_depth_states

        # Remember the surviving nodes for backtracking. A node is only
        # (re)added while its counter is below state_backtracking_limit.
        for state in next_depth_states:
            if state.count_added_to_backtrack < params["state_backtracking_limit"]:
                state.count_added_to_backtrack += 1
                node_depth = state.depth
                if node_depth not in last_valid_nodes:
                    last_valid_nodes[node_depth] = []
                if state not in last_valid_nodes[node_depth]:
                    last_valid_nodes[node_depth].append(state)

        # Dead end: nothing survived at this level, so backtrack to the
        # deepest remembered level. Every backtrack bumps those nodes'
        # counters. Once the first node's counter exceeds the limit, the level
        # is popped from memory, but its nodes are still expanded one last time.
        # NOTE(review): only the first node's counter is checked here, and this
        # test uses `<=` while the admission test above uses `<`. Confirm intended.
        if not open_list and last_valid_nodes:
            max_depth = max(last_valid_nodes.keys())
            print(f"No nodes in open_list, resuming search from last valid nodes at depth {max_depth}")
            for st in last_valid_nodes[max_depth]:
                st.count_added_to_backtrack += 1
            if last_valid_nodes[max_depth][0].count_added_to_backtrack <= params["state_backtracking_limit"]:
                open_list = last_valid_nodes[max_depth]
            else:
                open_list = last_valid_nodes.pop(max_depth)
            depth = max_depth
            # Forget earlier children so these nodes are expanded afresh
            # (generate_child_states uses child_states for uniqueness).
            for s in open_list:
                s.child_states = []
                s.active = True
        else:
            depth += 1

    print("Search completed without finding a goal state.")
    return False

def record_child_states(problem_dir, current_state, child_states):
    """Write a text summary of ``current_state`` and the children generated from it.

    Lists the parent's id and description, then each child's id, the action
    that produced it and its description.
    """
    out_dir = f"{problem_dir}<PATH_REMOVED>"
    os.makedirs(out_dir, exist_ok=True)
    out_path = f"{out_dir}<PATH_REMOVED>"
    with open(out_path, 'w') as out:
        out.write(f"Current State ID: {current_state.id}\n"
                  f"Current State Description: {current_state.state_nl}\n"
                  "\n")
        for kid in child_states:
            out.write(f"Child State ID: {kid.id}\n"
                      f"Action Taken: {kid.action_taken}\n"
                      f"State Description: {kid.state_nl}\n"
                      "\n")

def record_goal_state(problem_dir, state):
    """Write a summary of a found goal state and collect the diagrams on its path.

    Follows ``parent`` links from ``state`` back to the root and writes
    ``found_goal_state_information.txt`` with the goal state's fields. Then,
    for every state on the root-to-goal path whose ``diagram_picture`` file
    exists, copies a diagram into the goal state's directory.

    Copying is best-effort: a failed copy is printed and skipped. Files are
    copied with ``shutil.copy`` rather than a shell ``cp`` command, so paths
    with spaces or shell metacharacters are handled and never executed.
    """
    import shutil

    state_dir = f"{problem_dir}<PATH_REMOVED>"
    os.makedirs(state_dir, exist_ok=True)

    # Reconstruct the root-to-goal path via parent links.
    path = []
    node = state
    while node is not None:
        path.append(node)
        node = node.parent
    path.reverse()

    # Write the goal state properties to the file
    file_path = os.path.join(state_dir, "found_goal_state_information.txt")
    with open(file_path, 'w') as file:
        file.write(f"State ID: {state.id}\n")
        file.write(f"Action Path: {state.action_path}\n")
        file.write(f"Action Taken: {state.action_taken}\n")
        file.write(f"State Description: {state.state_nl}\n")
        file.write(f"Action Reasoning: {state.action_reasoning_nl}\n")
        file.write(f"Diagram Encoding: {state.diagram_encoding}\n")
        file.write(f"Diagram Code: {state.diagram_code}\n")
        file.write(f"Diagram Picture Path: {state.diagram_picture}\n")
        file.write(f"Cost: {state.cost}\n")
        file.write("\n")

    # Copy the diagrams of the intermediate states into the goal state directory
    for idx, s in enumerate(path):
        if s.diagram_picture and os.path.exists(s.diagram_picture):
            diagram_path = f"{state_dir}<PATH_REMOVED>"
            from_path = f"{problem_dir}<PATH_REMOVED>"
            try:
                shutil.copy(from_path, diagram_path)
            except OSError as err:
                print(f"Could not copy diagram of state {s.id}: {err}")

    print(f"Goal state and path diagrams recorded in {state_dir}")

def record_child_attempt(problem_dir, error_message, action_reasoning, new_state_nl, action_taken, nth_attempt, current_state):
    """Log one rejected child-generation attempt to ``child_attempt_<n>.txt``.

    The log holds the error message, the model's reasoning, the proposed
    state description and the chosen action.
    """
    log_dir = f"{problem_dir}<PATH_REMOVED>"
    if not os.path.exists(log_dir):
        os.makedirs(log_dir, exist_ok=True)
    log_path = os.path.join(log_dir, f"child_attempt_{nth_attempt}.txt")
    with open(log_path, 'w') as out:
        out.write(f"Attempt Number: {nth_attempt}\n"
                  f"Error Message: {error_message}\n\n"
                  "Action Reasoning:\n"
                  f"{action_reasoning}\n\n"
                  "New State Description:\n"
                  f"{new_state_nl}\n\n"
                  "Action Chosen:\n"
                  f"{action_taken}\n")

def generate_child_states(problem_dir, initial_state, current_state, last_id, goal_state, params, problem_name):
    """Sample and verify child states of ``current_state``.

    Each attempt asks ``inference.next_action`` for an action and the resulting
    state description. The call uses the module globals
    ``problem_description_nl`` and ``possible_actions``. The sampling
    temperature starts at ``params['Child_initial_temp']``, grows by
    ``params['Child_temp_incr']`` after every attempt and is capped at
    ``params['Max_child_temp']``.

    A candidate whose action is unique (it is the first child, or
    ``inference.is_unique_action`` says so) gets a state folder and info file.
    It is then accepted when ``generate_diagrams`` passes and
    ``inference.check_action_path`` accepts its full action path; otherwise it
    is marked inactive. Non-unique actions and diagram failures are fed back to
    the next ``next_action`` call via ``previous_attempt``/``error_message``.
    Exceptions inside an attempt are printed and logged, then the next attempt
    runs.

    Args:
        problem_dir: working directory for state folders and attempt logs.
        initial_state: root ``State`` (context for the model).
        current_state: ``State`` being expanded; accepted children are appended
            to ``current_state.child_states``.
        last_id: highest state id handed out so far.
        goal_state: goal ``State`` (context for the model).
        params: dict with the ``Child_*``/``Max_child_*`` keys above, plus the
            keys read by ``generate_diagrams``.
        problem_name: instance name; the domain is the part before ``"_instance"``.

    Returns:
        ``(current_state.child_states, last_id)`` after this expansion.
    """
    child_temp = params['Child_initial_temp']
    Max_child_attempts = params['Max_child_attempts']
    # NOTE(review): read from the same key as the attempt budget. Each child
    # costs at least one attempt, so `num_children < branching_factor` can
    # never end the loop before `attempts < Max_child_attempts` does. Confirm
    # whether a separate branching-factor setting was intended.
    branching_factor = params['Max_child_attempts']
    child_temp_incr = params['Child_temp_incr']
    max_child_temp = params['Max_child_temp']
    domain_name = problem_name.split("_instance")[0]
    num_children = 0
    attempts = 0
    previous_attempt = None
    error_message = None

    while num_children < branching_factor and attempts < Max_child_attempts:
        # Reset on every attempt so the except branch below never sees unbound
        # names or values left over from an earlier attempt.
        action_reasoning = new_state_nl = action_taken = None
        try:
            child_temp = min(child_temp, max_child_temp)
            # Actions of the active children accepted so far, recomputed on
            # every attempt so the model is told about all of them.
            chosen_actions = [child.action_taken for child in current_state.child_states if child.active]

            attempts += 1
            print(f"attempt number {attempts} for getting child states (temp = {child_temp})")
            action_reasoning, new_state_nl, action_taken = inference.next_action(
                problem_description_nl, possible_actions, initial_state, current_state, goal_state, model,
                child_temp, chosen_actions, previous_attempt, error_message
            )
            if action_taken == "":
                # Already counted above; log it and retry at the next temperature.
                print("model returned an empty action")
                record_child_attempt(problem_dir, "Empty action returned", action_reasoning, new_state_nl, action_taken, attempts, current_state)
                continue

            is_unique = inference.is_unique_action(
                problem_description_nl, current_state, action_taken, new_state_nl, model
            )
            if current_state.child_states and not is_unique:
                error_message = "Action not unique compared to other actions chosen from this state before. Choose a new action"
                print("new action not unique")
                record_child_attempt(problem_dir, error_message, action_reasoning, new_state_nl, action_taken, attempts, current_state)
                previous_attempt = action_taken
                continue

            last_id += 1
            new_child = State(
                id=last_id,
                parent=current_state,
                action_path=current_state.action_path + [action_taken],
                action_taken=action_taken,
                reasoning=action_reasoning,
                state_nl=new_state_nl
            )
            # Inherit the parent's per-action reasoning and add this step's.
            new_child.action_path_reasoning = dict(current_state.action_path_reasoning)
            new_child.action_path_reasoning[f"action {len(current_state.action_path)}:" + action_taken] = action_reasoning
            new_child.depth = current_state.depth + 1
            new_child.cost = current_state.cost + 1

            # Make folders and info file for this new child state
            setup_child_state(problem_dir, new_child.id)
            create_info_file(problem_dir, new_child, params)
            print(f"created the {num_children+1}th child state (temp = {child_temp})")

            previous_attempt = None
            error_message = None

            not_pass_test, error_message_action = generate_diagrams(domain_name, problem_dir, problem_description_nl, new_child, initial_state, model, params, goal_state)

            if not_pass_test:
                print(f"Child state {new_child.id} is invalid: {error_message_action}")
                new_child.active = False
                update_info_file_inactive(problem_dir, new_child, f"INACTIVE REASON: deactivated because diagram generation failed or invalid action chosen: {error_message_action}", None)
                # Feed the rejected action and its error to the next attempt.
                previous_attempt = new_child.action_taken
                error_message = error_message_action
                continue

            # Validate the whole action path leading to the new state.
            is_valid_path, reasoning = inference.check_action_path(
                problem_description_nl,
                initial_state,
                new_child,
                goal_state,
                new_child.action_path,
                possible_actions,
                model
            )
            if not is_valid_path:
                print(f"State {new_child.id} failed the action path verification test: {reasoning}")
                new_child.active = False
                update_info_file_inactive(problem_dir, new_child, f"INACTIVE REASON: action path invalid: {reasoning}")
            else:
                print(f"successfully verified the action of new state {new_child.id} with the path: {reasoning}")
                update_info_file_inactive(problem_dir, new_child, f"successfully verified the action of new state {new_child.id} with the path: {reasoning}")
                current_state.child_states.append(new_child)
                num_children += 1

        except Exception as e:
            # System errors are printed and logged, but deliberately not fed
            # back to the model.
            print(e)
            record_child_attempt(problem_dir, error_message, action_reasoning, new_state_nl, action_taken, attempts, current_state)
        finally:
            # Runs on every path (including `continue`), so each attempt
            # raises the temperature.
            child_temp += child_temp_incr

    return current_state.child_states, last_id

def record_diagram_attempt(problem_dir, error_message, state, nth_attempt):
    """Log one failed diagram attempt to ``diagram_attempt_<n>.txt``.

    The log holds the error message plus the state's current diagram encoding
    and diagram code.
    """
    log_dir = f"{problem_dir}<PATH_REMOVED>"
    if not os.path.exists(log_dir):
        os.makedirs(log_dir, exist_ok=True)
    log_path = os.path.join(log_dir, f"diagram_attempt_{nth_attempt}.txt")
    with open(log_path, 'w') as out:
        out.write(f"Attempt Number: {nth_attempt}\n"
                  f"Error Message: {error_message}\n\n"
                  "Diagram Encoding:\n"
                  f"{state.diagram_encoding}\n\n"
                  "Diagram Code:\n"
                  f"{state.diagram_code}\n")

def generate_diagrams(domain_name, problem_dir, problem_description, child_state, initial_state, model, params, goal_state):
    """Generate, render and verify the diagram of a newly created child state.

    Two steps share one attempt counter (``params['Max_diag_attempts']``):

    1. ``"encoding"``: ``inference.generate_diagram_encoding`` then
       ``inference.test_diagram_encoding``. Retried with the failed encoding
       and its error fed back while attempts remain.
    2. ``"code"``: ``inference.generate_diagram_code``, execution via
       ``execute_diagram_code``, then, once a picture exists, a re-check of
       the producing action with ``inference.check_action_validity``. Failed
       executions are retried with the failing code and error fed back.

    If the encoding budget runs out without a valid encoding, the code step
    still runs so the state gets a picture. The loop keeps going past the
    budget for as long as ``child_state.diagram_picture`` is None.

    Returns:
        ``(not_pass_tests, error_message_action)``. ``not_pass_tests`` is False
        only if both the encoding test and the action check passed.
        ``error_message_action`` is the last execution/action-check error fed
        back to the model, or None.
    """
    max_attempts = params['Max_diag_attempts']

    print(f"Getting diagram of state {child_state.id}")

    # Feedback passed to the next generation call of each step.
    previous_encoding = None
    previous_code = None
    error_message_encoding = None
    error_message_action = None

    # Track what actually passed instead of inferring it from the attempt
    # count, so a success on the last allowed attempt counts as a success.
    encoding_valid = False
    action_valid = False

    # "encoding" -> generate & test encoding; "code" -> generate, execute,
    # then re-check the action against the rendered diagram.
    step = "encoding"

    attempts = 0
    # NOTE(review): the `diagram_picture is None` clause keeps this loop
    # running past max_attempts until some picture is rendered. If the
    # generated code never executes cleanly, it does not terminate.
    while attempts < max_attempts or child_state.diagram_picture is None:
        if step == "encoding" and attempts < max_attempts:
            try:
                print(f"[Attempt {attempts+1}] Generating diagram encoding")
                # Count the attempt before calling the model, so a call that
                # raises still consumes budget instead of looping forever.
                attempts += 1
                diagram_encoding = inference.generate_diagram_encoding(
                    problem_description,
                    initial_state,
                    child_state,
                    model,
                    0.0,
                    previous_encoding,
                    error_message_encoding
                )

                child_state.diagram_encoding = diagram_encoding
                validity, error_message_encoding = inference.test_diagram_encoding(
                    problem_description, initial_state, child_state, model
                )

                if not validity:
                    print(f"Encoding invalid: {error_message_encoding}")
                    record_diagram_attempt(problem_dir, error_message_encoding, child_state, attempts+1)
                    previous_encoding = diagram_encoding
                else:
                    print("Diagram encoding is valid.")
                    encoding_valid = True
                    previous_encoding = None
                    error_message_encoding = None
                    step = "code"

            except Exception as e_enc:
                # Logged for humans only; system errors are not fed back to the model.
                print(f"Exception in encoding step: {e_enc}")
                record_diagram_attempt(problem_dir, f"Exception: {e_enc}", child_state, attempts+1)

        else:
            try:
                print(f"[Attempt {attempts+1}] Generating diagram code")
                save_path = f"{problem_dir}<PATH_REMOVED>"
                attempts += 1

                diagram_code = inference.generate_diagram_code(
                    domain_name,
                    problem_description,
                    child_state,
                    initial_state,
                    model,
                    0.0,
                    save_path,
                    previous_code,
                    error_message_action
                )
                child_state.diagram_code = diagram_code

                error_exec = execute_diagram_code(problem_dir, diagram_code, child_state.id, attempts+1, save_path)
                if error_exec:
                    error_message_action = f"Code execution failed: {error_exec}"
                    print(error_message_action)
                    record_diagram_attempt(problem_dir, error_message_action, child_state, attempts+1)
                    previous_code = diagram_code
                    # The encoding stays valid; retry the code step.
                    continue

                child_state.diagram_picture = save_path

                # Now that the child's diagram is generated, verify the action again
                validity_action, error_message_action = inference.check_action_validity(
                    problem_description_nl,
                    child_state.parent,
                    child_state.action_taken,
                    child_state.state_nl,
                    child_state.action_reasoning_nl,
                    child_state,
                    goal_state,
                    possible_actions,
                    initial_state,
                    model
                )

                if validity_action:
                    print(f"Successfully generated & tested diagram for state {child_state.id}")
                    error_message_action = None
                    action_valid = True
                else:
                    print(f"Diagram test failed: {error_message_action}")
                    record_diagram_attempt(problem_dir, error_message_action, child_state, attempts+1)
                # The action check is final either way.
                break

            except Exception as e_code:
                # Logged for humans only; system errors are not fed back to the model.
                print(f"Exception in code step: {e_code}")
                record_diagram_attempt(problem_dir, f"Exception: {e_code}", child_state, attempts+1)

    not_pass_tests = not (encoding_valid and action_valid)
    if not_pass_tests:
        print("Reached maximum attempts without a fully valid diagram.")

    store_diagram(problem_dir, child_state, not_pass_tests)
    return not_pass_tests, error_message_action
    

def execute_diagram_code(problem_dir, diagram_code, child_state_id, nth_attempt, save_path, timeout=10):
    """Save generated diagram code to disk and run it in a child Python process.

    The code is written as UTF-8 to ``diagram_code_<nth_attempt>.py`` in the
    attempts directory. It runs with the current interpreter
    (``sys.executable``), so it sees the same installed packages as this
    process. Any output on stderr, including warnings, counts as a failure.

    Args:
        problem_dir: problem working directory.
        diagram_code: Python source produced by the model.
        child_state_id: id of the state the diagram belongs to.
        nth_attempt: attempt number, used in the script file name.
        save_path: where the script is expected to save the picture (not
            inspected here).
        timeout: seconds the script may run before it is killed.

    Returns:
        None if the script finished without writing to stderr. Otherwise
        ``"Execution timed out"`` if it was killed after ``timeout`` seconds,
        else its stderr decoded as UTF-8 (undecodable bytes replaced).
    """
    diagram_attempts_dir = f"{problem_dir}<PATH_REMOVED>"
    os.makedirs(diagram_attempts_dir, exist_ok=True)
    diagram_code_path = os.path.join(diagram_attempts_dir, f"diagram_code_{nth_attempt}.py")
    # Python reads source files as UTF-8 by default, so write them that way.
    with open(diagram_code_path, 'w', encoding='utf-8') as file:
        file.write(diagram_code)

    cmd = [sys.executable or 'python', diagram_code_path]
    try:
        # subprocess.run kills the child before raising TimeoutExpired.
        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=timeout)
    except subprocess.TimeoutExpired:
        return "Execution timed out"

    stderr = result.stderr.decode('utf-8', errors='replace')
    if stderr:
        return stderr
    return None

def store_diagram(problem_dir, child_state, not_pass_tests):
    """Persist a state's diagram artifacts into its state directory.

    Moves the rendered picture, if its file exists, to ``diagram.png`` and
    points ``child_state.diagram_picture`` at the new location. Writes the
    diagram code and encoding when they are non-empty. When ``not_pass_tests``
    is true, also writes a ``diagram_status.txt`` marker.
    """
    target_dir = f"{problem_dir}<PATH_REMOVED>"
    if not os.path.exists(target_dir):
        os.makedirs(target_dir, exist_ok=True)

    picture = child_state.diagram_picture
    if picture and os.path.exists(picture):
        moved_to = os.path.join(target_dir, "diagram.png")
        os.rename(picture, moved_to)
        child_state.diagram_picture = moved_to

    for file_name, attr in (("diagram_code.py", "diagram_code"),
                            ("diagram_encoding.txt", "diagram_encoding")):
        content = getattr(child_state, attr)
        if content:
            with open(os.path.join(target_dir, file_name), 'w') as out:
                out.write(content)

    if not_pass_tests:
        with open(os.path.join(target_dir, "diagram_status.txt"), 'w') as out:
            out.write("This diagram did not pass the test but was stored due to running out of attempts.\n")

def get_last_valid_node(current_state, invalid_action_index):
    """Walk up from ``current_state`` to the node preceding an invalid action.

    Climbs ``invalid_action_index + 1`` parent links. The index is zero-based,
    so index 0 means one step back.

    Returns:
        The ancestor reached, or ``None`` when ``invalid_action_index`` is
        ``None`` or the walk runs out of parents before finishing.
    """
    if invalid_action_index is None:
        return None  # nothing to locate without an index

    node = current_state
    for _hop in range(invalid_action_index + 1):
        parent = node.parent
        if not parent:
            return None  # walked past the root
        node = parent
    return node

def move_nodes_to_tries(problem_dir):
    """Archive every ``state_*`` subdirectory of ``problem_dir`` into ``tries``.

    Any existing ``problem_dir/tries`` directory, with its contents, is deleted
    and recreated empty before the move. Plain files whose names start with
    ``state_`` are left where they are.
    """
    import shutil
    tries_dir = os.path.join(problem_dir, 'tries')
    if os.path.exists(tries_dir):
        shutil.rmtree(tries_dir)  # discard the previous archive
    os.makedirs(tries_dir, exist_ok=True)

    state_dirs = [
        os.path.join(problem_dir, name)
        for name in os.listdir(problem_dir)
        if name.startswith('state_') and os.path.isdir(os.path.join(problem_dir, name))
    ]
    for path in state_dirs:
        shutil.move(path, tries_dir)

def reset_problem_directory(problem_dir):
    """Delete search output from ``problem_dir``.

    Removes every ``state_*`` subdirectory and, if present, the ``ranking``
    directory. Plain files whose names start with ``state_`` are kept.
    """
    import shutil
    doomed = [
        os.path.join(problem_dir, name)
        for name in os.listdir(problem_dir)
        if name.startswith('state_') and os.path.isdir(os.path.join(problem_dir, name))
    ]
    for path in doomed:
        shutil.rmtree(path)

    ranking_dir = os.path.join(problem_dir, 'ranking')
    if os.path.exists(ranking_dir):
        shutil.rmtree(ranking_dir)

def create_info_file(problem_dir, state, params):
    """(Re)write the human-readable info file describing ``state``.

    Writes one ``Label: value`` line each for the parent's id (``None`` for a
    state without a parent), depth, action taken, action reasoning, state
    description and action path, in that order. Any existing file is
    truncated.

    ``params`` is currently not read. An earlier dump of the last few
    action/reasoning pairs, sized by ``params["action_reasoning_num"]``, was
    disabled.
    """
    info_file_path = f"{problem_dir}<PATH_REMOVED>"
    with open(info_file_path, 'w') as info:
        parent = state.parent
        info.write(f"Parent State ID: {parent.id if parent else None}\n")
        for label, attr in (
            ("State Depth", "depth"),
            ("Action Taken", "action_taken"),
            ("Action Reasoning", "action_reasoning_nl"),
            ("State Description", "state_nl"),
            ("Action Path", "action_path"),
        ):
            info.write(f"{label}: {getattr(state, attr)}\n")

def update_info_file_inactive(problem_dir, state, reason, invalid_action_idx=None):
    """Append ``reason`` (preceded by a blank line) to an existing info file.

    Does nothing when the info file has not been created yet, i.e. when the
    state's folder was never written. ``invalid_action_idx`` is accepted but
    not used by this function.
    """
    info_file_path = f"{problem_dir}<PATH_REMOVED>"
    if not os.path.exists(info_file_path):
        return
    with open(info_file_path, 'a') as info:
        info.write(f"\n{reason}\n")



def main():
    """Command-line entry point: parse arguments and run the beam search.

    Usage::

        python search.py <problem_name> <problem_dir> [--Key=VALUE ...]

    Each ``--Key=VALUE`` override must name a key of the default ``params``
    dict below. The value is converted to the type of that key's default.
    Unknown keys and arguments not of the form ``--Key=VALUE`` are reported
    as warnings and ignored. A value that cannot be converted is reported and
    the program exits with status 1, as it does when positional arguments are
    missing. Then a PDDL translation module is loaded from file and passed to
    ``beam_search``.
    """
    # The file-level `import importlib` does not guarantee that the
    # `importlib.util` submodule is loaded; import it explicitly.
    import importlib.util

    # Require the two positional args (problem_name, problem_dir); anything
    # after them is an optional --Key=VALUE override.
    if len(sys.argv) < 3:
        print("Usage: python search.py <problem_name> <problem_dir> "
              "[--Child_initial_temp=FLOAT] [--Beam_width=INT] [--Max_child_attempts=INT] ...")
        sys.exit(1)

    problem_name = sys.argv[1]
    problem_dir = sys.argv[2]

    # Default parameters
    params = {
        'Child_initial_temp': 0.5,
        'Beam_width': 4,
        'Max_child_attempts': 4,
        'Child_temp_incr': 0.4,
        'Max_child_temp': 1.2,
        'Max_diag_attempts': 4,
        'Max_id': 120,
        'Max_depth': 28,
        "state_backtracking_limit": 2,
    }

    def parse_value(key, val_str):
        """Convert ``val_str`` to the type of ``params[key]``'s default.

        Raises ValueError when the string is not a valid int/float.
        """
        default_val = params[key]
        if isinstance(default_val, int):
            return int(val_str)
        if isinstance(default_val, float):
            return float(val_str)
        return val_str  # fallback for non-numeric defaults

    for arg in sys.argv[3:]:
        if not (arg.startswith('--') and '=' in arg):
            print(f"[WARNING] Ignoring unrecognized argument '{arg}'.")
            continue
        # '=' cannot fall inside the '--' prefix, so this always yields two parts.
        key, val_str = arg[2:].split('=', 1)
        if key not in params:
            print(f"[WARNING] Unknown parameter '{key}' was provided.")
            continue
        try:
            params[key] = parse_value(key, val_str)
        except ValueError:
            print(f"[ERROR] Invalid value '{val_str}' for parameter '{key}' "
                  f"(expected {type(params[key]).__name__}).")
            sys.exit(1)

    k = params["Beam_width"]

    # Load the PDDL translation module dynamically from its file.
    # NOTE(review): domain_name is not referenced by the module_path literal as
    # it appears here; confirm the real path is built from it.
    domain_name = problem_name.split("_instance")[0]
    module_path = f'.<PATH_REMOVED>'
    module_name = 'PDDL_tranlation'
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    pddl_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(pddl_module)

    success = beam_search(problem_name, problem_dir, pddl_module, k, params)

    if success:
        print("Goal state found.")
            
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()